{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "36972001-cbed-46f6-8bed-e57feec3bbd4",
   "metadata": {},
   "source": [
    "# 使用领域（私有）数据微调 ChatGLM3\n",
    "\n",
    "生成带有 epoch 和 timestamp 的模型文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f2386615-d1d6-40c9-a014-b2bce85838ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch built with:\n",
      "  - GCC 9.3\n",
      "  - C++ Version: 201703\n",
      "  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications\n",
      "  - Intel(R) MKL-DNN v3.3.2 (Git Hash 2dc95a2ad0841e29db8b22fbccaf3e5da7992b01)\n",
      "  - OpenMP 201511 (a.k.a. OpenMP 4.5)\n",
      "  - LAPACK is enabled (usually provided by MKL)\n",
      "  - NNPACK is enabled\n",
      "  - CPU capability usage: AVX512\n",
      "  - CUDA Runtime 12.1\n",
      "  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90\n",
      "  - CuDNN 8.9.2\n",
      "  - Magma 2.6.1\n",
      "  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.2.2, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, \n",
      " _CudaDeviceProperties(name='Tesla T4', major=7, minor=5, total_memory=14930MB, multi_processor_count=40)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "print(torch.__config__.show(), torch.cuda.get_device_properties(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "deb6ec09-195a-49d0-aa2e-017a983df366",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['HF_HOME'] = '/mnt/huggingface/hf'\n",
    "os.environ['HF_HUB_CACHE'] = '/mnt/huggingface/hf/hub'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "280d5f7b-dada-49e1-81ee-d28a32900423",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义全局变量和参数\n",
    "model_name_or_path = 'THUDM/chatglm3-6b'  # 模型ID或本地路径\n",
    "# train_data_path = 'data/zhouyi_dataset_handmade.csv'    # 训练数据路径\n",
    "train_data_path = 'data/zhouyi_dataset_20240419_100810_review.csv'    # 训练数据路径(批量生成数据集）\n",
    "eval_data_path = None                     # 验证数据路径，如果没有则设置为None\n",
    "seed = 8                                 # 随机种子\n",
    "max_input_length = 512                    # 输入的最大长度\n",
    "max_output_length = 1536                  # 输出的最大长度\n",
    "lora_rank = 16                             # LoRA秩\n",
    "lora_alpha = 32                           # LoRA alpha值\n",
    "lora_dropout = 0.05                       # LoRA Dropout率\n",
    "prompt_text = ''                          # 所有数据前的指令文本"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "461e2caf-4dd4-4d4d-a8f6-7bfd674d5754",
   "metadata": {},
   "source": [
    "## 数据处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2c4d167c-0f39-4b11-8fad-ec608c2b3c3e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['content', 'summary'],\n",
      "        num_rows: 200\n",
      "    })\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"csv\", data_files=train_data_path)\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4d02d0c6-fc78-429c-9193-1453a91aed81",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import ClassLabel, Sequence\n",
    "import random\n",
    "import pandas as pd\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "def show_random_elements(dataset, num_examples=10):\n",
    "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
    "    picks = []\n",
    "    for _ in range(num_examples):\n",
    "        pick = random.randint(0, len(dataset)-1)\n",
    "        while pick in picks:\n",
    "            pick = random.randint(0, len(dataset)-1)\n",
    "        picks.append(pick)\n",
    "    \n",
    "    df = pd.DataFrame(dataset[picks])\n",
    "    for column, typ in dataset.features.items():\n",
    "        if isinstance(typ, ClassLabel):\n",
    "            df[column] = df[column].transform(lambda i: typ.names[i])\n",
    "        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):\n",
    "            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])\n",
    "    display(HTML(df.to_html()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4787d6f7-27ab-4822-b64e-2794bf52ec91",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>content</th>\n",
       "      <th>summary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>讼卦和启蒙教育之间有何联系？</td>\n",
       "      <td>在周易中，讼卦是一个寓意深远的卦象。它由上卦乾（天）和下卦坎（水）组成，象征着天与水相违行，即天高水深，事理乖舛。讼卦代表着争论与争讼之象，虽有利可图，但需要警惕戒惧。初期事情会有中间的顺利，但后来会出现凶险。值得注意的是，占得此卦有利于会见贵族王公，但不利于涉水渡河。\\n\\n讼卦在象辞中呈现出天高水深的局面，预示着事事都不宜强求，需要慎之又慎，没有从容处世的态度，很容易陷入争论纠纷。因此君子观此卦象，以杜绝争讼为意，从而在谋事之初必须慎之又慎。整体而言，讼卦提示人们应当谨慎处理事务，并避免与他人在争执中陷入对抗。在事业、经商、求名、婚恋和决策等方面，皆需要小心提防，退让让人，化解争端，方能避免不测之灾，最终得到吉祥和利益。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>周易中师卦的解释是什么？</td>\n",
       "      <td>在周易中，师卦是一个极具深意的卦象，它由两个异卦相叠组成：下卦坎（水）和上卦坤（地）。这一卦象代表“师”，即军队，寓意着兵力和农力的结合。在这里，坎卦象征着水和险难，而坤卦象征着地和顺从，暗示着通过将军事力量安置于民间，可以在必要时顺利调动。\\n\\n师卦的核心哲学是：虽然兵力代表着危险和战争，但其使用应当是圣人不得已而为之的最后手段。在正确的情况下，军事力量可以顺应形势，将危险转化为吉祥。因此，在军事策略上，此卦象征着出征将会顺利，无灾祸。\\n\\n师卦紧随讼卦（争讼卦），在《序卦》中解释为“讼必有众起，故受之以师”。这意味着争端激化至众多人群的参与，形成了类似军队的集体力量。在卜筮中，师卦出现，意味着困难重重，需要包容他人，艰苦努力，摒除一切困难。此外，在运势、事业、经商、求名、婚恋等方面也有着具体的解释和建议。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>比卦代表什么？</td>\n",
       "      <td>在周易中，比卦代表相亲相辅、和乐之象。下卦为坤（地），上卦为坎（水），水附大地，地纳河海，相互依赖，亲密无间。这是一个吉利的卦，预示着长期如此，就会无灾祸。同时再卜筮，仍然大吉大利。不愿臣服的邦国来朝，迟迟不来者有难。在运势上，表示平顺，可得贵人提拔，但需要以诚实、信任的态度去做事。在事业和经商上，需要与他人密切合作，真诚交往，讲究商业道德。对于婚恋和家宅，预示着美好姻缘、百年好合。总体而言，比卦传达了待人厚道、诚实宽厚的重要性，并且提醒着要选择朋友时谨慎选择，避免与品行不端的人结交。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>周易中需卦的象征意义是什么？</td>\n",
       "      <td>在周易中，需卦代表抓到俘虏，预示着大吉大利，适宜进行卜问。同时，有利于涉水渡河。卦象上卦为坎，代表云；下卦为乾，代表天。云浮聚于天上，待时降雨，正是需卦的卦象。君子观此卦象，可以安享宴饮，等待时机而行。需卦的核心哲学是：时机尚未成熟，需要耐心等待，切勿冒险，欲速不达。在事业、经商、求名等方面，都需要审时度势，耐心等待，充满自信且谨慎行事。而在婚恋关系中，更要慎重从事，培养感情，以诚实、热情相待。总之，需卦建议着重审时度势，耐心等待，并在时机成熟后稳步前行。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>乾卦是什么？</td>\n",
       "      <td>在周易中，乾卦是六十四卦之首，象征着天和龙（德才的君子），代表着纯粹的阳和健，意味着兴盛和强健。乾卦的占卜结果是大吉大利，代表着吉利和贞正。《象辞》中说，“天道刚健，运行不已”，君子观此卦象，从而以天为法，自强不息。这一卦象教导人遵守天道的德行，表明事业如日中天，但也需警惕盛极必衰的道理。在事业、经商、求名、婚恋和决策等方面，都需要保持积极谨慎的态度，坚持正直、公允的实质，修养德行，积累知识，自强不息，方能克服困难，取得成功。</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_random_elements(dataset[\"train\"], num_examples=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b99c93ee-dfaf-4f38-b499-202accf20a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,\n",
    "                                          trust_remote_code=True,\n",
    "                                          revision='b098244')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c680f5af-4c73-4f40-adfa-0e404e10d244",
   "metadata": {},
   "outputs": [],
   "source": [
    "# tokenize_func 函数\n",
    "def tokenize_func(example, tokenizer, ignore_label_id=-100):\n",
    "    \"\"\"\n",
    "    对单个数据样本进行tokenize处理。\n",
    "\n",
    "    参数:\n",
    "    example (dict): 包含'content'和'summary'键的字典，代表训练数据的一个样本。\n",
    "    tokenizer (transformers.PreTrainedTokenizer): 用于tokenize文本的tokenizer。\n",
    "    ignore_label_id (int, optional): 在label中用于填充的忽略ID，默认为-100。\n",
    "\n",
    "    返回:\n",
    "    dict: 包含'tokenized_input_ids'和'labels'的字典，用于模型训练。\n",
    "    \"\"\"\n",
    "\n",
    "    # 构建问题文本\n",
    "    question = prompt_text + example['content']\n",
    "    if example.get('input', None) and example['input'].strip():\n",
    "        question += f'\\n{example[\"input\"]}'\n",
    "\n",
    "    # 构建答案文本\n",
    "    answer = example['summary']\n",
    "\n",
    "    # 对问题和答案文本进行tokenize处理\n",
    "    q_ids = tokenizer.encode(text=question, add_special_tokens=False)\n",
    "    a_ids = tokenizer.encode(text=answer, add_special_tokens=False)\n",
    "\n",
    "    # 如果tokenize后的长度超过最大长度限制，则进行截断\n",
    "    if len(q_ids) > max_input_length - 2:  # 保留空间给gmask和bos标记\n",
    "        q_ids = q_ids[:max_input_length - 2]\n",
    "    if len(a_ids) > max_output_length - 1:  # 保留空间给eos标记\n",
    "        a_ids = a_ids[:max_output_length - 1]\n",
    "\n",
    "    # 构建模型的输入格式\n",
    "    input_ids = tokenizer.build_inputs_with_special_tokens(q_ids, a_ids)\n",
    "    question_length = len(q_ids) + 2  # 加上gmask和bos标记\n",
    "\n",
    "    # 构建标签，对于问题部分的输入使用ignore_label_id进行填充\n",
    "    labels = [ignore_label_id] * question_length + input_ids[question_length:]\n",
    "\n",
    "    return {'input_ids': input_ids, 'labels': labels}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b99af28d-86b1-41c1-84da-2d2a4efc200f",
   "metadata": {},
   "outputs": [],
   "source": [
    "column_names = dataset['train'].column_names\n",
    "tokenized_dataset = dataset['train'].map(\n",
    "    lambda example: tokenize_func(example, tokenizer),\n",
    "    batched=False, \n",
    "    remove_columns=column_names\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b749ba67-e890-4a42-a652-1eacf030ad47",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_dataset = tokenized_dataset.shuffle(seed=seed)\n",
    "tokenized_dataset = tokenized_dataset.flatten_indices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5fb9d942-e138-4ff4-a475-65d6e4f7d0c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from typing import List, Dict, Optional\n",
    "\n",
    "# DataCollatorForChatGLM 类\n",
    "class DataCollatorForChatGLM:\n",
    "    \"\"\"\n",
    "    用于处理批量数据的DataCollator，尤其是在使用 ChatGLM 模型时。\n",
    "\n",
    "    该类负责将多个数据样本（tokenized input）合并为一个批量，并在必要时进行填充(padding)。\n",
    "\n",
    "    属性:\n",
    "    pad_token_id (int): 用于填充(padding)的token ID。\n",
    "    max_length (int): 单个批量数据的最大长度限制。\n",
    "    ignore_label_id (int): 在标签中用于填充的ID。\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, pad_token_id: int, max_length: int = 2048, ignore_label_id: int = -100):\n",
    "        \"\"\"\n",
    "        初始化DataCollator。\n",
    "\n",
    "        参数:\n",
    "        pad_token_id (int): 用于填充(padding)的token ID。\n",
    "        max_length (int): 单个批量数据的最大长度限制。\n",
    "        ignore_label_id (int): 在标签中用于填充的ID，默认为-100。\n",
    "        \"\"\"\n",
    "        self.pad_token_id = pad_token_id\n",
    "        self.ignore_label_id = ignore_label_id\n",
    "        self.max_length = max_length\n",
    "\n",
    "    def __call__(self, batch_data: List[Dict[str, List]]) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"\n",
    "        处理批量数据。\n",
    "\n",
    "        参数:\n",
    "        batch_data (List[Dict[str, List]]): 包含多个样本的字典列表。\n",
    "\n",
    "        返回:\n",
    "        Dict[str, torch.Tensor]: 包含处理后的批量数据的字典。\n",
    "        \"\"\"\n",
    "        # 计算批量中每个样本的长度\n",
    "        len_list = [len(d['input_ids']) for d in batch_data]\n",
    "        batch_max_len = max(len_list)  # 找到最长的样本长度\n",
    "\n",
    "        input_ids, labels = [], []\n",
    "        for len_of_d, d in sorted(zip(len_list, batch_data), key=lambda x: -x[0]):\n",
    "            pad_len = batch_max_len - len_of_d  # 计算需要填充的长度\n",
    "            # 添加填充，并确保数据长度不超过最大长度限制\n",
    "            ids = d['input_ids'] + [self.pad_token_id] * pad_len\n",
    "            label = d['labels'] + [self.ignore_label_id] * pad_len\n",
    "            if batch_max_len > self.max_length:\n",
    "                ids = ids[:self.max_length]\n",
    "                label = label[:self.max_length]\n",
    "            input_ids.append(torch.LongTensor(ids))\n",
    "            labels.append(torch.LongTensor(label))\n",
    "\n",
    "        # 将处理后的数据堆叠成一个tensor\n",
    "        input_ids = torch.stack(input_ids)\n",
    "        labels = torch.stack(labels)\n",
    "\n",
    "        return {'input_ids': input_ids, 'labels': labels}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "478eec3d-3b00-4d67-9bdc-88535c10ca27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 准备数据整理器\n",
    "data_collator = DataCollatorForChatGLM(pad_token_id=tokenizer.pad_token_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "951d4a53-ed02-4a2f-a6be-1b95052786ce",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "d0ad8947-cfb8-47b0-90fd-5a8c93eed85a",
   "metadata": {},
   "source": [
    "## 加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "102589d4-f07c-4952-abf2-42eed291c01d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 7/7 [00:05<00:00,  1.40it/s]\n",
      "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModel, BitsAndBytesConfig\n",
    "\n",
    "_compute_dtype_map = {\n",
    "    'fp32': torch.float32,\n",
    "    'fp16': torch.float16,\n",
    "    'bf16': torch.bfloat16\n",
    "}\n",
    "\n",
    "# QLoRA 量化配置\n",
    "q_config = BitsAndBytesConfig(load_in_4bit=True,\n",
    "                              bnb_4bit_quant_type='nf4',\n",
    "                              bnb_4bit_use_double_quant=True,\n",
    "                              bnb_4bit_compute_dtype=_compute_dtype_map['bf16'])\n",
    "# 加载量化后模型\n",
    "model = AutoModel.from_pretrained(model_name_or_path,\n",
    "                                  quantization_config=q_config,\n",
    "                                  device_map='auto',\n",
    "                                  trust_remote_code=True,\n",
    "                                  revision='b098244')\n",
    "\n",
    "model.supports_gradient_checkpointing = True  \n",
    "model.gradient_checkpointing_enable()\n",
    "model.enable_input_require_grads()\n",
    "\n",
    "model.config.use_cache = False  # silence the warnings. Please re-enable for inference!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "28ce4c2d-d2a3-4e0a-a6a8-b66c8901b6a4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.\n"
     ]
    }
   ],
   "source": [
    "from peft import TaskType, LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
    "from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING\n",
    "\n",
    "kbit_model = prepare_model_for_kbit_training(model)\n",
    "target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING['chatglm']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5f5c327f-c34d-4c22-a72f-7a40757514b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "lora_config = LoraConfig(\n",
    "    target_modules=target_modules,\n",
    "    r=lora_rank,\n",
    "    lora_alpha=lora_alpha,\n",
    "    lora_dropout=lora_dropout,\n",
    "    bias='none',\n",
    "    inference_mode=False,\n",
    "    task_type=TaskType.CAUSAL_LM\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8c90d7ca-de06-416c-971d-b416e0f52ebd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 3,899,392 || all params: 6,247,483,392 || trainable%: 0.06241540401681151\n"
     ]
    }
   ],
   "source": [
    "qlora_model = get_peft_model(kbit_model, lora_config)\n",
    "qlora_model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26f3c6c2-28b5-4d2d-b611-a4c3bf90f133",
   "metadata": {},
   "source": [
    "### QLoRA 微调模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1f4233fa-adee-4fa2-9893-bb149636208b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "\n",
    "timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "\n",
    "train_epochs = 3\n",
    "output_dir = f\"models/{model_name_or_path}-epoch{train_epochs}-{timestamp}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "4f07e78f-8ff8-439a-8728-5b18b6a7626d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=output_dir,                            # 输出目录\n",
    "    per_device_train_batch_size=8,                     # 每个设备的训练批量大小\n",
    "    gradient_accumulation_steps=1,                     # 梯度累积步数\n",
    "    learning_rate=1e-3,                                # 学习率\n",
    "    num_train_epochs=train_epochs,                     # 训练轮数\n",
    "    lr_scheduler_type=\"linear\",                        # 学习率调度器类型\n",
    "    warmup_ratio=0.1,                                  # 预热比例\n",
    "    logging_steps=1,                                 # 日志记录步数\n",
    "    save_strategy=\"steps\",                             # 模型保存策略\n",
    "    save_steps=10,                                    # 模型保存步数\n",
    "    optim=\"adamw_torch\",                               # 优化器类型\n",
    "    fp16=True,                                        # 是否使用混合精度训练\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "12d86b61-0cf3-4b87-b866-3e81f58ad064",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/accelerate/accelerator.py:436: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
      "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)\n",
      "  warnings.warn(\n",
      "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
     ]
    }
   ],
   "source": [
    "trainer = Trainer(\n",
    "        model=qlora_model,\n",
    "        args=training_args,\n",
    "        train_dataset=tokenized_dataset,\n",
    "        data_collator=data_collator\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "3addcc9b-a23f-4b1f-a5ae-ce0cdfb57057",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='75' max='75' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [75/75 08:23, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>4.524700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>4.565700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>4.443300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>4.227600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>4.178500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>3.952400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>3.518400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>3.299400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>2.982700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>2.803500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>2.508900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>2.330200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>2.009400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>2.080200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>1.516700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>1.294200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>1.337500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>0.910700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>0.679400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.672700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>0.444700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>0.359500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>0.230200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>0.194200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25</td>\n",
       "      <td>0.125300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>0.081900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27</td>\n",
       "      <td>0.081400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28</td>\n",
       "      <td>0.052400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29</td>\n",
       "      <td>0.045100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.029500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31</td>\n",
       "      <td>0.027700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>32</td>\n",
       "      <td>0.022500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>33</td>\n",
       "      <td>0.017900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>34</td>\n",
       "      <td>0.017100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>35</td>\n",
       "      <td>0.017500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>36</td>\n",
       "      <td>0.012400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>37</td>\n",
       "      <td>0.015200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>38</td>\n",
       "      <td>0.013700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>39</td>\n",
       "      <td>0.011800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>0.008700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>41</td>\n",
       "      <td>0.009200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>42</td>\n",
       "      <td>0.008500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>43</td>\n",
       "      <td>0.010000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>44</td>\n",
       "      <td>0.006200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>45</td>\n",
       "      <td>0.006700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>46</td>\n",
       "      <td>0.005400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>47</td>\n",
       "      <td>0.006400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>48</td>\n",
       "      <td>0.005900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>49</td>\n",
       "      <td>0.004800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>0.004600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>51</td>\n",
       "      <td>0.003500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>52</td>\n",
       "      <td>0.003300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>53</td>\n",
       "      <td>0.003300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>54</td>\n",
       "      <td>0.003400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>55</td>\n",
       "      <td>0.003300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>56</td>\n",
       "      <td>0.003300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>57</td>\n",
       "      <td>0.002900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>58</td>\n",
       "      <td>0.002800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>59</td>\n",
       "      <td>0.003100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>0.003000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>61</td>\n",
       "      <td>0.002400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>62</td>\n",
       "      <td>0.002500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>63</td>\n",
       "      <td>0.002700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>64</td>\n",
       "      <td>0.002500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>65</td>\n",
       "      <td>0.002300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>66</td>\n",
       "      <td>0.002300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>67</td>\n",
       "      <td>0.002300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>68</td>\n",
       "      <td>0.002500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>69</td>\n",
       "      <td>0.002400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>0.002400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>71</td>\n",
       "      <td>0.002500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>72</td>\n",
       "      <td>0.002500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>73</td>\n",
       "      <td>0.002400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>74</td>\n",
       "      <td>0.002500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>75</td>\n",
       "      <td>0.002200</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=75, training_loss=0.7437443415106585, metrics={'train_runtime': 511.2514, 'train_samples_per_second': 1.174, 'train_steps_per_second': 0.147, 'total_flos': 6105935699116032.0, 'train_loss': 0.7437443415106585, 'epoch': 3.0})"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "cc1156a8-a6c4-42f7-bbaf-a6736d51bd0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.model.save_pretrained(output_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a1f764a-bbde-4862-9665-e5c0d67c640a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
